#!/usr/bin/env python
# -*- coding: utf-8 -*-

""" By Martin Senande-Rivera
	For Spatial and temporal expansion of global wildland fire activity in response to climate change """

import numpy as np
import xarray as xr

### File reading ###
####################

# Directory holding the pre-processed WFDE5 climatology files
# (the "interp025" suffix presumably marks 0.25-degree regridding - confirm).
ruta_wfde5 = '../../DATA/WFDE5/'

# "_a_" files hold annual means, "_m_" files hold monthly means, for total
# precipitation (tp) and 2 m temperature (t2m). Opened in the original order.
ds_tp_a = xr.open_dataset(ruta_wfde5 + 'wfde5_tp_a_mean_interp025.nc')
ds_t2m_a = xr.open_dataset(ruta_wfde5 + 'wfde5_t2m_a_mean_interp025.nc')
ds_tp_m = xr.open_dataset(ruta_wfde5 + 'wfde5_tp_m_mean_interp025.nc')
ds_t2m_m = xr.open_dataset(ruta_wfde5 + 'wfde5_t2m_m_mean_interp025.nc')

# Pull out the fields used below. Units are whatever the files store; the
# KG thresholds later assume mm and degrees Celsius - TODO confirm.
MAP = ds_tp_a['tp']   # mean annual precipitation
MAT = ds_t2m_a['t2m'] # mean annual 2 m temperature
Pm = ds_tp_m['tp']    # monthly precipitation means (indexed by 'month')
Tm = ds_t2m_m['t2m']  # monthly 2 m temperature means (indexed by 'month')

### Variable calculation ###
############################

# The year is split into two six-month halves:
# October-March (ONDJFM) and April-September (AMJJAS).
ONDJFM = [10, 11, 12, 1, 2, 3]
AMJJAS = [4, 5, 6, 7, 8, 9]


def _half_year(months):
    """Return (summed precipitation, mean temperature) over the given months."""
    return (Pm.sel(month=months).sum(dim='month'),
            Tm.sel(month=months).mean(dim='month'))


P_ONDJFM, T_ONDJFM = _half_year(ONDJFM)
P_AMJJAS, T_AMJJAS = _half_year(AMJJAS)

# Winter precipitation comes from the colder half-year and summer
# precipitation from the warmer one. Where the two half-year mean
# temperatures are equal (or NaN), neither comparison holds and the value is 0.
_no_season = xr.full_like(MAP, fill_value=0)
P_winter = xr.where(T_ONDJFM < T_AMJJAS, P_ONDJFM,
                    xr.where(T_AMJJAS < T_ONDJFM, P_AMJJAS, _no_season))
P_summer = xr.where(T_ONDJFM > T_AMJJAS, P_ONDJFM,
                    xr.where(T_AMJJAS > T_ONDJFM, P_AMJJAS, _no_season))

# Precipitation threshold of the KG scheme (kept under the name P_treshold):
#   2*MAT       where more than 70 % of MAP falls in winter,
#   2*MAT + 28  where more than 70 % of MAP falls in summer,
#   2*MAT + 14  everywhere else.
# The summer test is the outer one, so it wins if both tests were to hold.
_share_70 = MAP * 0.7
P_treshold = xr.where(
    P_summer > _share_70,
    2 * MAT + 28.,
    xr.where(P_winter > _share_70, 2 * MAT, 2 * MAT + 14.),
)

# Mean temperature of the coldest and of the hottest month.
T_cold, T_hot = Tm.min(dim='month'), Tm.max(dim='month')

### KG classification ###
#########################

# Main KG classes, encoded as integers:
#   0 = no condition matched (e.g. cells where MAP or the threshold is NaN),
#   1 = Tropical, 2 = Arid, 3 = Temperate, 4 = Cold, 5 = Polar.
# Rules are applied in table order and later matches overwrite earlier ones,
# so Polar overrides Temperate/Cold wherever the hottest month is <= 10.
# Thresholds assume MAT/T_cold/T_hot in degrees Celsius and MAP in mm - TODO confirm.
# NOTE(review): the usual KG scheme separates Temperate from Cold using the
# coldest-month temperature; here the split is MAT >= 2 - confirm intended.
_humid = MAP >= P_treshold * 10.
_arid = MAP < P_treshold * 10.  # not ~_humid: NaN cells must stay unclassified

KG = xr.full_like(MAP, fill_value=0)
for _code, _rule in (
    (1, _humid & (T_cold >= 18.)),               # Tropical
    (2, _arid),                                  # Arid
    (3, _humid & (T_cold < 18.) & (MAT >= 2.)),  # Temperate
    (4, _humid & (MAT < 2.)),                    # Cold
    (5, _humid & (T_hot <= 10.)),                # Polar
):
    KG = xr.where(_rule, _code, KG)

### Save files ###
##################
# Name the variable and write it to KG_class.nc in the working directory.
KG = KG.rename('KG')
KG.to_netcdf('KG_class.nc')
